{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0364fa99-3cad-4c11-ac41-6523fb98d187",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: HF_ENDPOINT=https://hf-mirror.com\n"
     ]
    }
   ],
   "source": [
    "%env HF_ENDPOINT=https://hf-mirror.com"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c654b825-84fd-43df-8412-53b1f9ecb8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# 设置 HF_HOME 环境变量 设置下载路径\n",
    "os.environ['HF_HOME'] = '/data1/ckw'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f30fc135-f12f-43bd-96e3-7ab02ef91296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install jieba -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e9e91c93-9b06-4cff-b826-02d1f4fecc5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.932 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "from tokenization_gptpangu import GPTPanguTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "94abdb98-fb74-42c0-805b-03df9fd12311",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# from transformers import AutoTokenizer\n",
    "from modeling_gptpangu import GPTPanguForCausalLM\n",
    "\n",
    "model = GPTPanguForCausalLM.from_pretrained(\"sunzeyeah/pangu-350M-sft\")#trust_remote_code=True\n",
    "# tokenizer = AutoTokenizer.from_pretrained(\"Apple/OpenELM-270M-Instruct\")Llama-2-7b-hf\n",
    "tokenizer = GPTPanguTokenizer.from_pretrained(\"sunzeyeah/pangu-350M-sft\")\n",
    "prompt = '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "\n",
    "# Generate\n",
    "generate_ids = model.generate(inputs.input_ids, max_length=300)\n",
    "tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "09bf8f6e-8c64-4c32-b289-71aa897a9b3f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里?'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里？\"\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "\n",
    "# Generate\n",
    "generate_ids = model.generate(inputs.input_ids, max_length=300)\n",
    "tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d9eff78c-7abf-4b05-9335-286f789fbaf0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   1,   96,   22,  337,   22,  691,   22, 3204,   22, 4672,   22, 6605,\n",
       "           11, 6539, 1249,   16, 1329,   28,    9]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.input_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a554f163-4226-476e-b8e1-5efe45b7988c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   1,   96,   22,  337,   22,  691,   22, 3204,   22, 4672,   22, 6605,\n",
       "           11, 6539, 1249,   16, 1329,   28,    9,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,\n",
       "            6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6,    6]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8846ecb1-e912-49f2-8f80-acb6d3e5304b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:515: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:520: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,我也是这样,']\n"
     ]
    }
   ],
   "source": [
    "prompt = \"我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点<sep>回答：\"\n",
    "inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n",
    "outputs = model.generate(**inputs,\n",
    "                         max_new_tokens=100,\n",
    "                         pad_token_id=tokenizer.pad_token_id,\n",
    "                         do_sample=False,\n",
    "                         num_return_sequences=1,\n",
    "                         top_p=0.8,\n",
    "                         temperature=0.8)\n",
    "results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "results = [result.split(\"答:\", maxsplit=1)[1] for result in results]\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "065dc7a0-2efa-4d14-9130-e99720f4f98c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['美国和日本和法国和加拿大和澳大利亚的首都分别是华盛顿和纽约']\n"
     ]
    }
   ],
   "source": [
    "prompt = \"中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里？<sep>回答：\"\n",
    "inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n",
    "outputs = model.generate(**inputs,\n",
    "                         max_new_tokens=100,\n",
    "                         pad_token_id=tokenizer.pad_token_id,\n",
    "                         do_sample=False,\n",
    "                         num_return_sequences=1,\n",
    "                         top_p=0.8,\n",
    "                         temperature=0.8)\n",
    "results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "results = [result.split(\"答:\", maxsplit=1)[1] for result in results]\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2e5d28c2-3415-416e-817e-a596b766febe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里?<sep>回答:美国和日本和法国和加拿大和澳大利亚的首都分别是华盛顿和纽约']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.batch_decode(outputs, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "26acd04e-1462-49c2-b0dc-234d0a82db73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Datawhale是一个数据库,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它是一个数据库管理系统,它']\n"
     ]
    }
   ],
   "source": [
    "prompt = \"你知道有关datawhale的信息么？<sep>回答：\"\n",
    "inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n",
    "outputs = model.generate(**inputs,\n",
    "                         max_new_tokens=100,\n",
    "                         pad_token_id=tokenizer.pad_token_id,\n",
    "                         do_sample=False,\n",
    "                         num_return_sequences=1,\n",
    "                         top_p=0.8,\n",
    "                         temperature=0.8)\n",
    "results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "results = [result.split(\"答:\", maxsplit=1)[1] for result in results]\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c503178-b46b-445d-9555-bb529acecb47",
   "metadata": {},
   "outputs": [],
   "source": [
    "Pangu-350M经过sft,只有符合指令才会有输出.同时,数据量较少,还是不能涵盖很多问题"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
